import os
import json
import shutil
import logging

import tensorflow as tf
from conlleval import return_report
from data_utils import load_data_test


# Relative directories, resolved against the current working directory.
models_path = "./models"
eval_path = "./evaluation"
# Both evaluation artefacts live directly under eval_path.
eval_temp, eval_script = (os.path.join(eval_path, name) for name in ("temp", "conlleval"))


def get_logger(log_file):
    """
    Return a logger named ``log_file`` that logs to that file and to the console.

    The file handler records DEBUG and above; the console (stderr) handler
    records INFO and above. Both share a timestamped format.

    ``logging.getLogger`` returns the same object for the same name, so
    handlers are attached only on the first call. Later calls with the same
    ``log_file`` return the already-configured logger instead of stacking
    duplicate handlers, which would emit every message more than once.
    """
    logger = logging.getLogger(log_file)
    logger.setLevel(logging.DEBUG)
    if logger.handlers:
        # Already configured by a previous call.
        return logger

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)

    # Explicit utf8, matching every other file this module writes.
    fh = logging.FileHandler(log_file, encoding="utf8")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)

    logger.addHandler(ch)
    logger.addHandler(fh)
    return logger

def test_as(results, path):
    """
    Write predictions to ``path/as_predict.utf8`` and score them.

    ``results`` is an iterable of blocks; each block is an iterable of
    already-formatted lines. Every line is written followed by a newline,
    and each block is terminated by an extra blank line. The written file is
    then handed to ``return_report`` and its result is returned.
    """
    report_input = os.path.join(path, "as_predict.utf8")
    with open(report_input, "w", encoding='utf8') as handle:
        buffered = []
        for sentence in results:
            buffered.extend(text + "\n" for text in sentence)
            buffered.append("\n")
        handle.writelines(buffered)
    return return_report(report_input)


def print_config(config, logger):
    """
    Log each configuration entry at INFO level as ``key<pad>:<TAB>value``.

    Keys are left-justified to 15 characters so the values line up.
    """
    template = "{}:\t{}"
    for name, setting in config.items():
        padded = name.ljust(15)
        logger.info(template.format(padded, setting))


def make_path(params):
    """
    Make folders for training and evaluation.

    Creates ``params.result_path``, ``params.ckpt_path`` and ``log``
    (relative to the current working directory), including any missing
    parents. Directories that already exist are left untouched.
    """
    for directory in (params.result_path, params.ckpt_path, "log"):
        # exist_ok=True avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(directory, exist_ok=True)


def clean(params):
    """
    Clean current folder: remove saved model, generated files and training log.

    Removes, when present:
      * the files ``params.vocab_file``, ``params.map_file`` and ``params.config_file``;
      * the directories ``params.ckpt_path``, ``params.summary_path`` and
        ``params.result_path``, plus ``log`` and ``__pycache__`` relative to
        the current working directory.
    Paths that do not exist (or are not of the expected kind) are skipped.
    """
    for file_path in (params.vocab_file, params.map_file, params.config_file):
        if os.path.isfile(file_path):
            os.remove(file_path)

    for dir_path in (params.ckpt_path, params.summary_path, params.result_path, "log", "__pycache__"):
        if os.path.isdir(dir_path):
            shutil.rmtree(dir_path)


def save_config(config, config_file):
    """
    Save configuration of the model as pretty-printed UTF-8 JSON.

    Non-ASCII text is written as-is (``ensure_ascii=False``) with 4-space
    indentation. The config is serialised *before* the file is opened, so a
    value ``json`` cannot encode raises (e.g. ``TypeError``) without
    truncating or partially overwriting an existing ``config_file``.
    """
    serialized = json.dumps(config, ensure_ascii=False, indent=4)
    with open(config_file, "w", encoding="utf8") as f:
        f.write(serialized)


def load_config(config_file):
    """
    Read the model configuration from a UTF-8 JSON file and return the
    decoded object.
    """
    with open(config_file, encoding="utf8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def convert_to_text(line):
    """
    Convert conll data to text.

    Each item of ``line`` is expected to be ``"word gold tag"`` separated by
    single spaces; an item starting with a space is emitted as one space.
    Entities are rendered as ``[chars@TYPE]``: a tag starting with ``S`` or
    ``B`` opens a bracket, a tag starting with ``S`` or ``E`` closes it, and
    ``TYPE`` is the text after the tag's last ``-``.

    Malformed items (empty, wrong number of fields, empty tag) are printed
    as a list of characters and skipped. Only those parsing errors are
    caught; anything else (including KeyboardInterrupt) propagates.
    """
    to_print = []
    for item in line:
        try:
            if item[0] == " ":
                to_print.append(" ")
                continue
            word, _gold, tag = item.split(" ")
            prefix = tag[0]
        except (IndexError, ValueError):
            # Best-effort: report the malformed item and keep going.
            print(list(item))
            continue
        if prefix in "SB":
            to_print.append("[")
        to_print.append(word)
        if prefix in "SE":
            to_print.append("@" + tag.split("-")[-1])
            to_print.append("]")
    return "".join(to_print)


def save_model(sess, model, path, logger, step):
    """
    Save the session's variables through ``model.saver`` under
    ``path/as.ckpt``, tagged with ``step``, then log "model saved".
    """
    model.saver.save(sess, os.path.join(path, "as.ckpt"), global_step=step)
    logger.info("model saved")


def create_model(session, Model_class, path, load_vec, config, id_to_char, logger):
    """
    Build ``Model_class(config)`` and give it parameters.

    If ``tf.train.get_checkpoint_state(path)`` yields a checkpoint that
    ``tf.train.checkpoint_exists`` confirms, the parameters are restored from
    it. Otherwise all global variables are initialised and, when
    ``config["pre_emb"]`` is truthy, ``model.char_lookup`` is replaced by the
    value returned from
    ``load_vec(config["emb_file"], id_to_char, config["char_dim"], current_weights)``.

    Returns the model instance.
    """
    model = Model_class(config)

    ckpt = tf.train.get_checkpoint_state(path)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        # Reuse previously trained parameters.
        saved_path = ckpt.model_checkpoint_path
        logger.info("Reading model parameters from %s" % saved_path)
        model.saver.restore(session, saved_path)
        return model

    logger.info("Created model with fresh parameters.")
    session.run(tf.global_variables_initializer())
    if not config["pre_emb"]:
        return model

    # load_vec receives the freshly initialised table and returns the one to assign.
    current_weights = session.run(model.char_lookup.read_value())
    pretrained = load_vec(config["emb_file"], id_to_char, config["char_dim"], current_weights)
    session.run(model.char_lookup.assign(pretrained))
    logger.info("Load pre-trained embedding.")
    return model


def result_to_json(string, tags):
    """
    Group per-character tags into entity spans.

    ``string`` and ``tags`` are walked in lockstep (extra items in the longer
    one are ignored). A tag's first character selects the action (``S``
    single-character entity, ``B`` begin, ``I`` inside, ``E`` end, anything
    else outside), and ``tag[2:]`` is the entity type (e.g. ``"PER"`` for
    ``"B-PER"``).

    Returns ``{"string": string, "entities": [...]}`` where each entity is
    ``{"word", "start", "end", "type"}`` and ``word == string[start:end]``.

    The pending entity is always the text ``string[entity_start:idx]``. ``B``
    starts a fresh entity; ``S``, ``E`` and non-entity tags close it. The
    previous code appended on ``B`` and left ``entity_start`` stale after
    ``S``/``O``, which broke that invariant for malformed tag sequences.
    Well-formed sequences produce the same output as before.
    """
    item = {"string": string, "entities": []}
    entity_name = ""
    entity_start = 0
    for idx, (char, tag) in enumerate(zip(string, tags)):
        prefix = tag[0]
        if prefix == "S":
            item["entities"].append({"word": char, "start": idx, "end": idx + 1, "type": tag[2:]})
            entity_name, entity_start = "", idx + 1
        elif prefix == "B":
            entity_name, entity_start = char, idx
        elif prefix == "I":
            entity_name += char
        elif prefix == "E":
            entity_name += char
            item["entities"].append({"word": entity_name, "start": entity_start, "end": idx + 1, "type": tag[2:]})
            entity_name, entity_start = "", idx + 1
        else:
            entity_name, entity_start = "", idx + 1
    return item

def train_as(results, path, i):
    """
    Write ``results`` to ``path/<i>.utf8``.

    ``results`` is an iterable of blocks; each block's lines are written one
    per line, and every block is followed by an extra blank line.
    """
    target = os.path.join(path, str(i) + ".utf8")
    with open(target, "w", encoding='utf8') as out:
        chunks = []
        for sentence in results:
            chunks.extend(text + "\n" for text in sentence)
            chunks.append("\n")
        out.writelines(chunks)

def test_best_as(results, path):
    """
    Write predictions to ``path/predict.utf8`` without scoring them.

    ``results`` is an iterable of blocks; each block is an iterable of
    already-formatted lines. Every line is written followed by a newline,
    and each block is terminated by an extra blank line. Unlike ``test_as``,
    no evaluation report is produced and nothing is returned.
    """
    output_file = os.path.join(path, "predict.utf8")
    with open(output_file, "w", encoding='utf8') as f:
        to_write = []
        for block in results:
            to_write.extend(line + "\n" for line in block)
            to_write.append("\n")
        f.writelines(to_write)

def generate_start_index(context, answer_start_index, answer):
    """
    Convert token-level answer start positions into character offsets.

    For example ``i``, ``context[i]`` is a sequence of tokens. Elsewhere in
    this module, ``generate_repositioning_data`` joins those tokens with
    single spaces. A token index ``ind`` maps to the character offset where
    token ``ind`` begins in that joined string, i.e. the sum of
    ``len(token) + 1`` over the first ``ind`` tokens. Non-positive indices
    map to 0.

    Args:
        context: per-example token sequences.
        answer_start_index: per-example lists of token indices.
        answer: per-example answer lists; used only in the error message.

    Returns:
        A list (one entry per example) of lists of character offsets.

    Raises:
        IndexError: if a start index exceeds the number of context tokens.
    """
    update_index = []
    for i, context_part in enumerate(context):
        answer_part = answer[i]
        update_index_part = []
        for j, ind in enumerate(answer_start_index[i]):
            if ind > len(context_part):
                raise IndexError(
                    "answer start index {} out of range for context of {} tokens "
                    "(example {}, answer {!r})".format(ind, len(context_part), i, answer_part[j])
                )
            # Each preceding token contributes its length plus one separating space.
            update_index_part.append(sum(len(token) + 1 for token in context_part[:max(ind, 0)]))
        update_index.append(update_index_part)
    return update_index

def generate_repositioning_data(path, context, query, answer, update_index, i):
    """
    Write labelled examples as a SQuAD-style JSON file ``path/<i>_repositioning.json``.

    For each example ``k``:
      * ``context[k]`` is a token sequence, written space-joined as the paragraph context;
      * ``query[k][j]`` / ``answer[k][j]`` are token sequences for the j-th
        question and its answer, written space-joined;
      * ``update_index[k][j]`` is written verbatim as ``answer_start``
        (presumably a character offset such as ``generate_start_index``
        produces — confirm at the call site).
    Question ids are consecutive strings ``"0"``, ``"1"``, ... across the
    whole file. The top-level ``"version"`` is ``"1.1"``.

    The output file is opened with a context manager so it is always flushed
    and closed (the previous version never closed it).
    """
    output_file = os.path.join(path, str(i) + "_repositioning.json")

    data = []
    qa_id = 0
    for k in range(len(context)):
        context_part = context[k]
        query_part = query[k]
        answer_part = answer[k]
        update_index_part = update_index[k]
        qas = []
        for j in range(len(query_part)):
            qas.append({
                "answers": [{"answer_start": update_index_part[j], "text": " ".join(answer_part[j])}],
                "question": " ".join(query_part[j]),
                "id": str(qa_id),
            })
            qa_id += 1
        data.append({"title": " ", "paragraphs": [{"context": " ".join(context_part), "qas": qas}]})

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({"data": data, "version": "1.1"}, f)

def squad_test_data(context, query, path):
    """
    Write unlabelled test examples in SQuAD-style JSON plus a question->context map.

    Creates two files in ``path``:
      * ``repositioning_test.json``: one article per context (tokens
        space-joined). Every query becomes a question whose single
        placeholder answer has ``answer_start`` and ``text`` set to ``""``.
        The top-level ``"version"`` is ``"1.1"``.
      * ``que2con.json``: maps each question id to the index of its context
        in ``context``. JSON object keys are strings, so ids load back as
        ``"0"``, ``"1"``, ...

    Both files are written through context managers so they are always
    flushed and closed (the previous version never closed them).
    """
    squad_file = os.path.join(path, "repositioning_test.json")
    que2con_file = os.path.join(path, "que2con.json")

    data = []
    que2con = {}
    qa_id = 0
    for ctx_idx in range(len(context)):
        context_part = " ".join(context[ctx_idx])
        query_part = query[ctx_idx]
        qas = []
        for j in range(len(query_part)):
            qas.append({
                "answers": [{"answer_start": "", "text": ""}],
                "question": " ".join(query_part[j]),
                "id": str(qa_id),
            })
            que2con[qa_id] = ctx_idx
            qa_id += 1
        data.append({"title": " ", "paragraphs": [{"context": context_part, "qas": qas}]})

    with open(squad_file, "w", encoding="utf-8") as f:
        json.dump({"data": data, "version": "1.1"}, f)
    with open(que2con_file, "w", encoding="utf-8") as f1:
        json.dump(que2con, f1)

def generate_repositioning_test_data(path):
    """
    Load test contexts and queries from ``path`` with ``load_data_test``.
    If any contexts were found, write the SQuAD-style test files into the
    same directory via ``squad_test_data``; otherwise do nothing.
    """
    context, query = load_data_test(path)
    if len(context) == 0:
        return
    squad_test_data(context, query, path)